Recalibration of Gaussian Neural Network regression models:

the recalibratiNN package

Carolina Musso

Instituto de Pesquisa e Estatística do DF, Brazil

2024-07-11

A proper introduction

  • Disclaimer: Me, the package and everything else.
    • Academic: Biological invasions, Fire Ecology, Ecotoxicology …
    • Public Sector: Epidemiology, Sampling design and inference…
    • Free time: Bachelor's degree in Statistics, computational statistics, Bayesian methods, neural networks and recalibration.
  • R!
  • Basically, I really wanted to develop a package.

Introduction: Neural Networks nowadays

  • A good predictive model should be able to quantify its uncertainty.
  • NNs can be constructed to produce probabilistic results:
    • Optimized by the log-likelihood.
    • Like any model, they can be miscalibrated.
      • A 95% CI should contain the true output 95% of the time.
      • \(\mathbb{P}(Y \leq \hat{F}_Y^{-1}(p)) = p, ~ \forall ~ p \in [0,1]\)
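The calibration condition above is easy to check empirically. A minimal base-R sketch with simulated data (not package code): when the predictive distribution matches the data-generating one, the proportion of observations below the predictive quantile at level \(p\) is close to \(p\).

```r
set.seed(1)
n   <- 1e5
mu  <- rnorm(n)                       # per-observation predictive means
y   <- rnorm(n, mean = mu, sd = 1)    # data drawn from the same model
pit <- pnorm(y, mean = mu, sd = 1)    # P(Y <= y) under the predictive CDF
mean(pit <= 0.95)                     # close to 0.95, as the condition requires
```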

Note

If the model is optimized by MSE, we are implicitly assuming a normal predictive distribution.
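This note can be made precise: for a Gaussian likelihood with fixed variance, minimizing the negative log-likelihood is equivalent to minimizing the MSE,

\[ -\log L = \frac{n}{2}\log(2\pi\sigma^2) + \frac{1}{2\sigma^2}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2, \]

so the minimizer over the predictions \(\hat{y}_i\) is the same in both cases.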

Observing miscalibration

Consider a synthetic data set \((x_i, y_i),\ i = 1, \dots, n\), generated by a heteroscedastic non-linear model:

\[ x_i \sim \text{Uniform}(1, 10) \]

\[ y_i \mid x_i \sim \text{Normal}(\mu = f_1(x_i), \ \sigma = f_2(x_i)), \qquad f_1(x) = 5x^2 + 10, \quad f_2(x) = 30x \]

and the fitted model is a simple linear regression:

\[ y_i = \beta_0 + \beta_1 x_i + \epsilon_i, \qquad \epsilon_i \overset{iid}{\sim} N(0, \sigma^2) \]

Observing miscalibration

A simple linear regression, just to warm up.

  • Global Coverage: 94.45%.
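A global coverage figure like the one above can be reproduced in a few lines of base R. This is a sketch with freshly simulated data, so the exact 94.45% will not be matched; the point is that a single global sigma can give near-nominal global coverage even when the model is badly misspecified.

```r
set.seed(123)
n   <- 5000
x   <- runif(n, 1, 10)
y   <- rnorm(n, mean = 5 * x^2 + 10, sd = 30 * x)  # heteroscedastic data
fit <- lm(y ~ x)                                   # misspecified linear fit
mu_hat  <- fitted(fit)
sig_hat <- summary(fit)$sigma                      # one global sigma for every x
covered <- y >= qnorm(0.025, mu_hat, sig_hat) &
           y <= qnorm(0.975, mu_hat, sig_hat)
mean(covered)   # global coverage looks roughly nominal despite the wrong model
```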


PIT - Values

  • Histogram of Probability Integral Transform (PIT) values.

  • Let \(F_Y(y)\) be the CDF of a continuous random variable Y, then:

\[ U = F_Y(Y) \sim \text{Uniform}(0, 1) \]

  • In particular, if \(Y \sim Normal(\mu, \sigma)\):

\[ Y = F_Y^{-1}(U) \sim \text{Normal}(\mu, \sigma) \]
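A small base-R illustration of this property (simulated values, not package code): PIT values from a correctly specified model are approximately uniform, while a misspecified model distorts their distribution.

```r
set.seed(7)
y <- rnorm(1e4, mean = 2, sd = 3)
u_good <- pnorm(y, mean = 2, sd = 3)  # correct model: approximately Uniform(0, 1)
u_bad  <- pnorm(y, mean = 2, sd = 1)  # understated sd: mass piles up at 0 and 1
# hist(u_good); hist(u_bad)           # compare the two histograms
```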

Visualizing PIT-values

Recalibration

Available Packages

  • R: probably

  • Python: ml_insights

  • These offer only global recalibration, focus on classification problems, and apply only in the covariate space.

Method:

  • Torres et al. (2024): calibration across various representations of the covariate space, useful for Artificial Neural Networks (ANNs).

Algorithm
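In rough terms, the local recalibration idea can be sketched as follows: for a new point, collect the PIT values of its nearest calibration neighbours (in the covariate space or in a hidden-layer representation) and invert them through the new point's predictive CDF to obtain calibrated samples. This is an illustrative base-R reading of the approach, not the package's implementation; `local_recalibrate` and all of its arguments are hypothetical names.

```r
# Hypothetical sketch of local recalibration; not the recalibratiNN internals.
local_recalibrate <- function(x_new, mu_new, sd_new, x_cal, pit_cal, k = 50) {
  # Euclidean distances from the new point to every calibration point
  d  <- sqrt(rowSums((x_cal - matrix(x_new, nrow(x_cal), ncol(x_cal),
                                     byrow = TRUE))^2))
  nb <- order(d)[seq_len(k)]                      # indices of k nearest neighbours
  qnorm(pit_cal[nb], mean = mu_new, sd = sd_new)  # invert PITs through the new CDF
}
```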

The Package

recalibratiNN package

  • 7 functions & 10 dependencies

| Function | Description | Arguments |
|---|---|---|
| PIT_global | Calculates PIT values for the entire dataset | ycal, yhat, mse |
| PIT_local | Calculates PIT values for each cluster | xcal, ycal, yhat, mse, clusters, p_neighbours, PIT |
| gg_PIT_global | Plots PIT values histogram | pit, type, fill, alpha, print_p |
| gg_PIT_local | Plots PIT values densities for k-means clusters | pit_local, alpha, linewidth, pal, facet |
| recalibrate | Recalibrates the model | yhat_new, space_new, space_cal, pit_values, mse, type, p_neighbours, epsilon |

Visualizing miscalibration

Global Calibration

pit <- PIT_global(ycal = y_cal,     # true values from calibration set
                  yhat = y_hat_cal, # predictions for calibration set
                  mse  = MSE_cal)   # MSE from calibration set

gg_PIT_global(pit,
              type = "histogram",
              fill = "steelblue4",
              alpha = 0.8,
              print_p = TRUE)

Local Calibration

pit_local <- PIT_local(xcal = x_cal, 
                       ycal = y_cal, 
                       yhat = y_hat_cal, 
                       mse = MSE_cal,
                       clusters = 6,
                       p_neighbours = 0.2,
                       PIT = PIT_global)

gg_PIT_local(pit_local)

Neural Networks

Neural Network example

Data

set.seed(42)   # The Answer to the Ultimate Question of Life, The Universe, and Everything

n <- 10000

x <- cbind(x1 = runif(n, -3, 3),
           x2 = runif(n, -5, 5))

mu_fun <- function(x) {
  abs(x[,1]^3 - 50*sin(x[,2]) + 30)}

mu <- mu_fun(x)
y <- rnorm(n,
           mean = mu,
           sd = 20 * abs(x[, 2] / (x[, 1] + 10)))

split1 <- 0.6
split2 <- 0.8

x_train <- x[1:(split1*n),]
y_train <- y[1:(split1*n)]

x_cal  <- x[(split1*n+1):(n*split2),]
y_cal  <- y[(split1*n+1):(n*split2)]

x_test <- x[(split2*n+1):n,]
y_test  <- y[(split2*n+1):n]

Keras

model_nn <- keras_model_sequential()

model_nn |>
  layer_dense(input_shape = 2,
              units = 800,
              use_bias = TRUE,
              activation = "relu",
              kernel_initializer = "random_normal",
              bias_initializer = "zeros") |>
  layer_dropout(rate = 0.1) |>
  layer_dense(units = 800,
              use_bias = TRUE,
              activation = "relu",
              kernel_initializer = "random_normal",
              bias_initializer = "zeros") |>
  layer_dropout(rate = 0.1) |>
  layer_dense(units = 800,
              use_bias = TRUE,
              activation = "relu",
              kernel_initializer = "random_normal",
              bias_initializer = "zeros") |>
  layer_batch_normalization() |>
  layer_dense(units = 1,
              activation = "linear",
              kernel_initializer = "zeros",
              bias_initializer = "zeros")

model_nn |>
  compile(optimizer = optimizer_adam(),
          loss = "mse")

model_nn |>
  fit(x = x_train,
      y = y_train,
      validation_data = list(x_cal, y_cal),
      callbacks = list(callback_early_stopping(
        monitor = "val_loss",
        patience = 20,
        restore_best_weights = TRUE)),
      batch_size = 128,
      epochs = 1000)


y_hat_cal <- predict(model_nn, x_cal)
y_hat_test <- predict(model_nn, x_test)

Observing miscalibration


Coverage


Recalibration

recalibrated <- recalibrate(
  pit_values = pit,       # global PIT values calculated earlier
  mse = MSE_cal,          # MSE from the calibration set
  yhat_new = y_hat_test,  # predictions for the test set
  space_cal = x_cal,      # covariates of the calibration set
  space_new = x_test,     # covariates of the test set
  type = "local",         # type of recalibration
  p_neighbours = 0.08)    # proportion of the calibration set used as nearest neighbours

y_hat_rec <- recalibrated$y_samples_calibrated_wt

  • That is it!
  • These new values in y_hat_rec are better calibrated than the original predictions.
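That claim can be checked empirically by computing the coverage of intervals built from the predictive samples. A self-contained sketch: here a simulated matrix `samples` stands in for an output such as y_hat_rec, with one row of predictive samples per test observation (names and shapes are illustrative).

```r
set.seed(1)
n_test  <- 500
mu_true <- rnorm(n_test, 0, 5)
y_test  <- rnorm(n_test, mu_true, 2)
# one row of 200 predictive samples per test observation
samples <- t(sapply(mu_true, function(m) rnorm(200, mean = m, sd = 2)))
lo <- apply(samples, 1, quantile, probs = 0.025)
hi <- apply(samples, 1, quantile, probs = 0.975)
mean(y_test >= lo & y_test <= hi)   # empirical coverage, near 0.95
```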

Shall we see?


Coverage


Real data

Diamonds dataset

After Recalibration

Recalibrated using the representation from the second hidden layer.

Conclusions and Future Work

  • Effective visualization of miscalibration.
  • Advantages relative to other packages:
    • Focused on regression models.
    • Local recalibration.
    • Recalibration at intermediate layers.

Future Developments:

  • Integration with other packages, broader input types, cross-validation methods

  • Handle models with arbitrary predictive distributions.

Thank You!

[GitHub](